{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f6787afb",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# 多头注意力\n",
    ":label:`sec_multihead-attention`\n",
    "\n",
    "在实践中，当给定相同的查询、键和值的集合时，\n",
    "我们希望模型可以基于相同的注意力机制学习到不同的行为，\n",
    "然后将不同的行为作为知识组合起来，\n",
    "捕获序列内各种范围的依赖关系\n",
    "（例如，短距离依赖和长距离依赖关系）。\n",
    "因此，允许注意力机制组合使用查询、键和值的不同\n",
    "*子空间表示*（representation subspaces）可能是有益的。\n",
    "\n",
    "为此，与其只使用单独一个注意力汇聚，\n",
    "我们可以用独立学习得到的$h$组不同的\n",
    "*线性投影*（linear projections）来变换查询、键和值。\n",
    "然后，这$h$组变换后的查询、键和值将并行地送到注意力汇聚中。\n",
    "最后，将这$h$个注意力汇聚的输出拼接在一起，\n",
    "并且通过另一个可以学习的线性投影进行变换，\n",
    "以产生最终输出。\n",
    "这种设计被称为*多头注意力*（multihead attention）\n",
    " :cite:`Vaswani.Shazeer.Parmar.ea.2017`。\n",
    "对于$h$个注意力汇聚输出，每一个注意力汇聚都被称作一个*头*（head）。\n",
    " :numref:`fig_multi-head-attention`\n",
    "展示了使用全连接层来实现可学习的线性变换的多头注意力。\n",
    "\n",
    "![多头注意力：多个头连结然后线性变换](../img/multi-head-attention.svg)\n",
    ":label:`fig_multi-head-attention`\n",
    "\n",
    "## 模型\n",
    "\n",
    "在实现多头注意力之前，让我们用数学语言将这个模型形式化地描述出来。\n",
    "给定查询$\\mathbf{q} \\in \\mathbb{R}^{d_q}$、\n",
    "键$\\mathbf{k} \\in \\mathbb{R}^{d_k}$和\n",
    "值$\\mathbf{v} \\in \\mathbb{R}^{d_v}$，\n",
    "每个注意力头$\\mathbf{h}_i$（$i = 1, \\ldots, h$）的计算方法为：\n",
    "\n",
    "$$\\mathbf{h}_i = f(\\mathbf W_i^{(q)}\\mathbf q, \\mathbf W_i^{(k)}\\mathbf k,\\mathbf W_i^{(v)}\\mathbf v) \\in \\mathbb R^{p_v},$$\n",
    "\n",
    "其中，可学习的参数包括\n",
    "$\\mathbf W_i^{(q)}\\in\\mathbb R^{p_q\\times d_q}$、\n",
    "$\\mathbf W_i^{(k)}\\in\\mathbb R^{p_k\\times d_k}$和\n",
    "$\\mathbf W_i^{(v)}\\in\\mathbb R^{p_v\\times d_v}$，\n",
    "以及代表注意力汇聚的函数$f$。\n",
    "$f$可以是 :numref:`sec_attention-scoring-functions`中的\n",
    "加性注意力和缩放点积注意力。\n",
    "多头注意力的输出需要经过另一个线性转换，\n",
    "它对应着$h$个头连结后的结果，因此其可学习参数是\n",
    "$\\mathbf W_o\\in\\mathbb R^{p_o\\times h p_v}$：\n",
    "\n",
    "$$\\mathbf W_o \\begin{bmatrix}\\mathbf h_1\\\\\\vdots\\\\\\mathbf h_h\\end{bmatrix} \\in \\mathbb{R}^{p_o}.$$\n",
    "\n",
    "基于这种设计，每个头都可能会关注输入的不同部分，\n",
    "可以表示比简单加权平均值更复杂的函数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "23568774",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T16:54:02.679903Z",
     "iopub.status.busy": "2022-12-07T16:54:02.679345Z",
     "iopub.status.idle": "2022-12-07T16:54:04.937385Z",
     "shell.execute_reply": "2022-12-07T16:54:04.936235Z"
    },
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e067395",
   "metadata": {
    "origin_pos": 5
   },
   "source": [
    "## 实现\n",
    "\n",
    "在实现过程中通常[**选择缩放点积注意力作为每一个注意力头**]。\n",
    "为了避免计算代价和参数代价的大幅增长，\n",
    "我们设定$p_q = p_k = p_v = p_o / h$。\n",
    "值得注意的是，如果将查询、键和值的线性变换的输出数量设置为\n",
    "$p_q h = p_k h = p_v h = p_o$，\n",
    "则可以并行计算$h$个头。\n",
    "在下面的实现中，$p_o$是通过参数`num_hiddens`指定的。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "59bfd2ef",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T16:54:04.941737Z",
     "iopub.status.busy": "2022-12-07T16:54:04.941358Z",
     "iopub.status.idle": "2022-12-07T16:54:04.951008Z",
     "shell.execute_reply": "2022-12-07T16:54:04.949891Z"
    },
    "origin_pos": 7,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    \"\"\"多头注意力\"\"\"\n",
    "    def __init__(self, key_size, query_size, value_size, num_hiddens,\n",
    "                 num_heads, dropout, bias=False, **kwargs):\n",
    "        super(MultiHeadAttention, self).__init__(**kwargs)\n",
    "        self.num_heads = num_heads\n",
    "        self.attention = d2l.DotProductAttention(dropout)\n",
    "        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)\n",
    "        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)\n",
    "        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)\n",
    "        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)\n",
    "\n",
    "    def forward(self, queries, keys, values, valid_lens):\n",
    "        # queries，keys，values的形状:\n",
    "        # (batch_size，查询或者“键－值”对的个数，num_hiddens)\n",
    "        # valid_lens　的形状:\n",
    "        # (batch_size，)或(batch_size，查询的个数)\n",
    "        # 经过变换后，输出的queries，keys，values　的形状:\n",
    "        # (batch_size*num_heads，查询或者“键－值”对的个数，\n",
    "        # num_hiddens/num_heads)\n",
    "        queries = transpose_qkv(self.W_q(queries), self.num_heads)\n",
    "        keys = transpose_qkv(self.W_k(keys), self.num_heads)\n",
    "        values = transpose_qkv(self.W_v(values), self.num_heads)\n",
    "\n",
    "        if valid_lens is not None:\n",
    "            # 在轴0，将第一项（标量或者矢量）复制num_heads次，\n",
    "            # 然后如此复制第二项，然后诸如此类。\n",
    "            valid_lens = torch.repeat_interleave(\n",
    "                valid_lens, repeats=self.num_heads, dim=0)\n",
    "\n",
    "        # output的形状:(batch_size*num_heads，查询的个数，\n",
    "        # num_hiddens/num_heads)\n",
    "        output = self.attention(queries, keys, values, valid_lens)\n",
    "\n",
    "        # output_concat的形状:(batch_size，查询的个数，num_hiddens)\n",
    "        output_concat = transpose_output(output, self.num_heads)\n",
    "        return self.W_o(output_concat)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7651da0a",
   "metadata": {
    "origin_pos": 10
   },
   "source": [
    "为了能够[**使多个头并行计算**]，\n",
    "上面的`MultiHeadAttention`类将使用下面定义的两个转置函数。\n",
    "具体来说，`transpose_output`函数反转了`transpose_qkv`函数的操作。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b7330027",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T16:54:04.954339Z",
     "iopub.status.busy": "2022-12-07T16:54:04.953881Z",
     "iopub.status.idle": "2022-12-07T16:54:04.960681Z",
     "shell.execute_reply": "2022-12-07T16:54:04.959640Z"
    },
    "origin_pos": 12,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "def transpose_qkv(X, num_heads):\n",
    "    \"\"\"为了多注意力头的并行计算而变换形状\"\"\"\n",
    "    # 输入X的形状:(batch_size，查询或者“键－值”对的个数，num_hiddens)\n",
    "    # 输出X的形状:(batch_size，查询或者“键－值”对的个数，num_heads，\n",
    "    # num_hiddens/num_heads)\n",
    "    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)\n",
    "\n",
    "    # 输出X的形状:(batch_size，num_heads，查询或者“键－值”对的个数,\n",
    "    # num_hiddens/num_heads)\n",
    "    X = X.permute(0, 2, 1, 3)\n",
    "\n",
    "    # 最终输出的形状:(batch_size*num_heads,查询或者“键－值”对的个数,\n",
    "    # num_hiddens/num_heads)\n",
    "    return X.reshape(-1, X.shape[2], X.shape[3])\n",
    "\n",
    "\n",
    "#@save\n",
    "def transpose_output(X, num_heads):\n",
    "    \"\"\"逆转transpose_qkv函数的操作\"\"\"\n",
    "    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])\n",
    "    X = X.permute(0, 2, 1, 3)\n",
    "    return X.reshape(X.shape[0], X.shape[1], -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b6aff10",
   "metadata": {
    "origin_pos": 15
   },
   "source": [
    "下面使用键和值相同的小例子来[**测试**]我们编写的`MultiHeadAttention`类。\n",
    "多头注意力输出的形状是（`batch_size`，`num_queries`，`num_hiddens`）。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "51deccc4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T16:54:04.964215Z",
     "iopub.status.busy": "2022-12-07T16:54:04.963461Z",
     "iopub.status.idle": "2022-12-07T16:54:04.990943Z",
     "shell.execute_reply": "2022-12-07T16:54:04.989832Z"
    },
    "origin_pos": 17,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultiHeadAttention(\n",
       "  (attention): DotProductAttention(\n",
       "    (dropout): Dropout(p=0.5, inplace=False)\n",
       "  )\n",
       "  (W_q): Linear(in_features=100, out_features=100, bias=False)\n",
       "  (W_k): Linear(in_features=100, out_features=100, bias=False)\n",
       "  (W_v): Linear(in_features=100, out_features=100, bias=False)\n",
       "  (W_o): Linear(in_features=100, out_features=100, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_hiddens, num_heads = 100, 5\n",
    "attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,\n",
    "                               num_hiddens, num_heads, 0.5)\n",
    "attention.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9658ae9f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T16:54:04.994206Z",
     "iopub.status.busy": "2022-12-07T16:54:04.993742Z",
     "iopub.status.idle": "2022-12-07T16:54:05.006054Z",
     "shell.execute_reply": "2022-12-07T16:54:05.004979Z"
    },
    "origin_pos": 20,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 4, 100])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_size, num_queries = 2, 4\n",
    "num_kvpairs, valid_lens =  6, torch.tensor([3, 2])\n",
    "X = torch.ones((batch_size, num_queries, num_hiddens))\n",
    "Y = torch.ones((batch_size, num_kvpairs, num_hiddens))\n",
    "attention(X, Y, Y, valid_lens).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec18ac5b",
   "metadata": {
    "origin_pos": 22
   },
   "source": [
    "## 小结\n",
    "\n",
    "* 多头注意力融合了来自于多个注意力汇聚的不同知识，这些知识的不同来源于相同的查询、键和值的不同的子空间表示。\n",
    "* 基于适当的张量操作，可以实现多头注意力的并行计算。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 分别可视化这个实验中的多个头的注意力权重。\n",
    "1. 假设有一个完成训练的基于多头注意力的模型，现在希望修剪最不重要的注意力头以提高预测速度。如何设计实验来衡量注意力头的重要性呢？\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43f4552c",
   "metadata": {
    "origin_pos": 24,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "[Discussions](https://discuss.d2l.ai/t/5758)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}